import time
import json
import os
import random
import re
import shutil
from functools import partial
from multiprocessing import Pool
from jinja2 import Template

import fire
import tqdm
from rouge_score import rouge_scorer

from src.util.io import read_jsonl, write_jsonl
from src.util.openai import openai_batch_completion, OpenAIDecodingArguments


# Anything that is not a lowercase Latin letter, Cyrillic letter (incl. "ё")
# or ASCII digit; runs of such characters are treated as one separator.
NON_ALPHANUM_RE = re.compile(r"[^a-zа-яё0-9]+")


def tokenize(text):
    """Split *text* into lowercase Latin/Cyrillic alphanumeric tokens.

    The text is lowercased first; every run of characters outside
    ``a-z``, ``а-я``, ``ё`` and ``0-9`` then acts as a token separator.
    """
    lowered = text.lower()
    spaced = NON_ALPHANUM_RE.sub(" ", lowered)
    tokens = spaced.split()
    return tokens


def encode_prompt(example_chars, template_path):
    """Render the generation prompt from a Jinja2 template file.

    Args:
        example_chars: list of character dicts shown to the model as examples.
            The bookkeeping keys "most_similar_chars" and "avg_similarity_score"
            are left out of the rendered JSON. The input dicts are NOT modified.
        template_path: path to a Jinja2 template that references the
            ``example_chars`` variable (rendered as a JSON string).

    Returns:
        The rendered prompt, stripped, with exactly one trailing newline.
    """
    # Explicit UTF-8: templates may contain non-ASCII (e.g. Cyrillic) text and
    # the platform default encoding is not guaranteed to handle it.
    with open(template_path, encoding="utf-8") as f:
        template = Template(f.read())
    excluded_keys = {"most_similar_chars", "avg_similarity_score"}
    # Build filtered copies instead of popping in place: callers pass elements
    # of their own seed/generated lists, and mutating them would silently alter
    # records that are later written back to disk.
    cleaned_chars = [
        {key: value for key, value in char.items() if key not in excluded_keys}
        for char in example_chars
    ]
    return template.render(
        example_chars=json.dumps(cleaned_chars, ensure_ascii=False)
    ).strip() + "\n"


def post_process(response):
    """Extract generated character dicts from one chat-completion response.

    The message content is parsed as JSON. Two shapes are accepted: a top-level
    list of characters, or an object whose "characters" key holds that list.
    Only dict items with a string "context" field are kept, since downstream
    deduplication reads that field.

    Args:
        response: a single completion choice, indexable as
            ``response["finish_reason"]`` and ``response["message"]["content"]``.
            May be falsy (e.g. a failed request).

    Returns:
        A list of character dicts. It is empty for a falsy response, a response
        truncated by the length limit, unparseable JSON, or an unexpected shape.
        The result is always a list, never None.
    """
    if not response:
        return []
    # A length-truncated completion almost certainly holds invalid/partial JSON.
    if response["finish_reason"] == "length":
        return []
    raw_content = response["message"]["content"]
    try:
        parsed = json.loads(raw_content)
    except (TypeError, ValueError):  # None content / malformed JSON
        return []
    if isinstance(parsed, dict):
        parsed = parsed.get("characters")
    if not isinstance(parsed, list):
        return []
    return [
        char for char in parsed
        if isinstance(char, dict) and isinstance(char.get("context"), str)
    ]


def generate_chars(
    output_path: str,
    seed_chars_path: str,
    template_path: str,
    num_chars_to_generate: int = 200,
    model_name: str = "gpt-4",
    request_batch_size: int = 5,
    temperature: float = 1.0,
    top_p: float = 0.95,
    num_cpus: int = 8,
    rouge_cutoff: float = 0.24
):
    """Generate new characters with an OpenAI chat model, deduplicated by ROUGE-L.

    Each prompt shows the model two example characters. Once any generated
    characters exist, these are one generated and one seed character;
    otherwise they are two seed characters. A returned character is dropped
    when its "context" has a ROUGE-L F-measure above ``rouge_cutoff`` against
    any seed or already-kept character. After every request the kept
    characters are written to ``output_path``, so an interrupted run resumes
    from that file.

    Args:
        output_path: JSONL file to resume from (if it exists) and write to.
        seed_chars_path: JSONL file of seed characters, each with a "context" field.
        template_path: Jinja2 prompt template passed to ``encode_prompt``.
        num_chars_to_generate: keep requesting until at least this many
            generated characters exist (the last batch may overshoot).
        model_name: OpenAI chat model name.
        request_batch_size: number of prompts sent per batch request.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.
        num_cpus: worker processes used for ROUGE scoring.
        rouge_cutoff: maximum allowed ROUGE-L F-measure against existing contexts.
    """
    random.seed(43)
    # Close the file deterministically and skip blank lines (e.g. a trailing
    # newline), which json.loads would otherwise reject.
    with open(seed_chars_path, "r", encoding="utf-8") as f:
        seed_chars = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(seed_chars)} character examples")

    machine_chars = []
    if os.path.exists(output_path):
        machine_chars = read_jsonl(output_path)
        print(f"Loaded {len(machine_chars)} machine-generated characters")

    # Token lists of every context seen so far; new candidates are compared to all.
    all_description_tokens = [tokenize(d["context"]) for d in seed_chars + machine_chars]

    request_idx = 0
    is_prompt_printed = False
    is_output_printed = False
    # One worker pool for the whole run. Creating a Pool per candidate spawned
    # num_cpus processes for every single character and dominated runtime.
    with Pool(num_cpus) as pool, tqdm.tqdm(total=num_chars_to_generate) as progress_bar:
        if machine_chars:
            progress_bar.update(len(machine_chars))

        while len(machine_chars) < num_chars_to_generate:
            request_idx += 1

            batch = []
            for _ in range(request_batch_size):
                if machine_chars:
                    prompt_chars = random.sample(machine_chars, 1)
                    prompt_chars += random.sample(seed_chars, 1)
                else:
                    prompt_chars = random.sample(seed_chars, 2)
                random.shuffle(prompt_chars)

                prompt = encode_prompt(prompt_chars, template_path)
                batch.append([{"role": "user", "content": prompt}])

            if not is_prompt_printed:
                is_prompt_printed = True
                print("Prompt example:")
                for message in batch[0]:
                    print("Role: {}, content: {}".format(message["role"], message["content"]))

            request_start = time.time()
            results = openai_batch_completion(
                batch=batch,
                model_name=model_name,
                decoding_args=OpenAIDecodingArguments(
                    temperature=temperature,
                    top_p=top_p
                )
            )
            # Measure before printing so the timing reflects only the request.
            request_duration = time.time() - request_start
            # Guard against an empty result list or a failed (falsy) first result.
            if not is_output_printed and results and results[0]:
                is_output_printed = True
                print("Output example:")
                print(results[0].message["content"])

            process_start = time.time()
            new_chars = []
            for result in results:
                new_chars.extend(post_process(result) or [])

            total = len(new_chars)
            keep = 0
            for new_char in new_chars:
                context = new_char.get("context") if isinstance(new_char, dict) else None
                if not isinstance(context, str):
                    continue  # malformed model output; nothing to compare against
                new_description_tokens = tokenize(context)
                rouge_scores = pool.map(
                    partial(rouge_scorer._score_lcs, new_description_tokens),
                    all_description_tokens,
                )
                max_score = max((score.fmeasure for score in rouge_scores), default=0.0)
                if max_score > rouge_cutoff:
                    continue

                keep += 1
                machine_chars.append(new_char)
                # Later candidates in this same batch are also checked against it.
                all_description_tokens.append(new_description_tokens)
                progress_bar.update(1)

            process_duration = time.time() - process_start
            print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s")
            print(f"Generated {total} chars, kept {keep} chars")
            print("===================================")

            # Write to a temp file then move, so a crash mid-write never
            # leaves a truncated output file behind.
            write_jsonl(machine_chars, output_path + "_tmp")
            shutil.move(output_path + "_tmp", output_path)


if __name__ == "__main__":
    # python-fire maps command-line flags onto generate_chars' parameters.
    fire.Fire(component=generate_chars)
